package cn.iocoder.yudao.module.bpm.util;

import cn.hutool.core.util.PageUtil;
import cn.hutool.db.DbUtil;
import cn.hutool.db.Page;
import cn.hutool.db.ds.simple.SimpleDataSource;
import cn.hutool.extra.spring.SpringUtil;
import com.baomidou.dynamic.datasource.DynamicRoutingDataSource;
import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.extension.toolkit.JdbcUtils;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.ResultSet;

/**
 * @title: 数据源工具类
 * @Author tzx
 * @Date: 2022/4/17 14:55
 * @Version 1.0
 */
@Slf4j
public class DatasourceUtil {


    /**
     * 获取主数据源
     *
     * @return
     */
    public static DataSource getDatasourceMaster() {
        DynamicRoutingDataSource dynamicRoutingDataSource = SpringUtil.getBean(DynamicRoutingDataSource.class);
        return dynamicRoutingDataSource.getDataSource("master");
    }

    /**
     * 获取数据源
     *
     * @return
     */
    public static DataSource getDataSource(String id) {
        DynamicRoutingDataSource dynamicRoutingDataSource = SpringUtil.getBean(DynamicRoutingDataSource.class);
        try {
            return dynamicRoutingDataSource.getDataSource(id);
        } catch (Exception e) {
            throw new RuntimeException("获取数据源失败： " + id);
        }
    }





    /**
     * 关闭链接
     *
     * @return
     */
    @SneakyThrows
    public static void close(Connection connection, ResultSet resultSet) {
        if (resultSet != null) {
            resultSet.close();
        }
        if (connection != null) {
            connection.close();
        }
    }

    /**
     * 构建分页查询sql
     *
     * @param sql
     * @param page
     * @return
     */
    public static String wrapperPageSql(String sql, Page page, DbType dbType) {
        switch (dbType) {
            case MYSQL:
                return wrapperPageSqlForMysql(sql, page);
            case ORACLE:
            case ORACLE_12C:
                return wrapperPageSqlForOracle(sql, page);
            case SQL_SERVER:
                return wrapperPageSqlForSqlServer(sql, page);
            case POSTGRE_SQL:
                return wrapperPageSqlForPostgreSql(sql, page);
            case DM:
                return wrapperPageSqlForDm(sql, page);
            case DB2:
                return wrapperPageSqlForDb2(sql, page);
            case KINGBASE_ES:
                return wrapperPageSqlForKingBaseEs(sql, page);
            case GAUSS:
                return wrapperPageSqlForGauss(sql, page);
            default:
                return sql;
        }
    }

    private static String wrapperPageSqlForMysql(String sql, Page page) {
        return "select * from (" + sql + ") t" + " limit " + PageUtil.getStart(page.getPageNumber(), page.getPageSize()) + "," +page.getPageSize();
    }

    private static String wrapperPageSqlForOracle(String sql, Page page) {
        return "select * from ( SELECT row_.*, rownum rownum_ from ( " + sql + " ) row_ where rownum <= " + PageUtil.getStart(page.getPageNumber(), page.getPageSize()) + ") table_alias where table_alias.rownum_ >" + PageUtil.getEnd(page.getPageNumber(), page.getPageSize());
    }

    private static String wrapperPageSqlForSqlServer(String sql, Page page) {
        return "select * FROM ( " + sql + " ) t ORDER BY current_timestamp offset " + PageUtil.getStart(page.getPageNumber(), page.getPageSize()) + " rows fetch next " + PageUtil.getEnd(page.getPageNumber(), page.getPageSize()) + " rows only";
    }

    private static String wrapperPageSqlForPostgreSql(String sql, Page page) {
        return "select * FROM ( " + sql + " ) t limit " + page.getPageNumber() + " offset " + page.getPageSize();
    }

    private static String wrapperPageSqlForDm(String sql, Page page) {
        return "select * from (" + sql + ") t" + " limit " + PageUtil.getStart(page.getPageNumber(), page.getPageSize()) + "," + PageUtil.getEnd(page.getPageNumber(), page.getPageSize());
    }

    private static String wrapperPageSqlForDb2(String sql, Page page) {
        return "select row_num() over() as rownum,* from ( " + sql + " ) where rownum > " + PageUtil.getStart(page.getPageNumber(), page.getPageSize()) + " and rownum < " + PageUtil.getEnd(page.getPageNumber(), page.getPageSize());
    }

    private static String wrapperPageSqlForKingBaseEs(String sql, Page page) {
        return "select * FROM ( " + sql + " ) t limit " + page.getPageNumber() + " offset " + page.getPageSize();
    }

    private static String wrapperPageSqlForGauss(String sql, Page page) {
        return "select * FROM ( " + sql + " ) t limit " + page.getPageNumber() + " offset " + page.getPageSize();
    }

    private static String wrapperPageCount(String sql, Page page) {
        return "select count(*) from (" + sql + ") t";
    }

    /**
     * 构建测试sql
     *
     * @return
     */
    public static String wrapperTestSql(DbType dbType) {
        switch (dbType) {
            case ORACLE:
            case ORACLE_12C:
                return "SELECT 1 FROM DUAL";
            default:
                return "select 1";
        }
    }

    /**
     * 构建测试sql
     *
     * @return
     */
    public static String wrapperDropSql(String tableName) {
        return "drop table " + tableName;
    }



    /**
     * 构建 判断表是否存在的sql
     * @param tableName
     * @param dbType
     * @return
     */
    private static String wrapperExistTableSql(String tableName, DbType dbType) {
        switch (dbType) {
            case ORACLE:
            case ORACLE_12C:
                return wrapperExistTableSqlForOracle(tableName);
            case SQL_SERVER:
                return wrapperExistTableSqlForSqlServer(tableName);
            case POSTGRE_SQL:
                return wrapperExistTableSqlForPostgreSql(tableName);
            case DM:
                return wrapperExistTableSqlForDm(tableName);
            case DB2:
                return wrapperExistTableSqlForDb2(tableName);
            case KINGBASE_ES:
                return wrapperExistTableSqlForKingBaseEs(tableName);
            case GAUSS:
                return wrapperExistTableSqlForGauss(tableName);
            default:
                return wrapperExistTableSqlForMysql(tableName);
        }
    }

    private static String wrapperExistTableSqlForMysql(String tableName) {
        return "select count(*) from information_schema.TABLES where TABLE_NAME = '" + tableName + "'";
    }

    private static String wrapperExistTableSqlForOracle(String tableName) {
        return "select count(*) from user_tables where table_name = '" + tableName + "'";
    }

    private static String wrapperExistTableSqlForSqlServer(String tableName) {
        return "select count(*) from sys.objects where object_id = OBJECT_ID('" + tableName + "')";
    }

    private static String wrapperExistTableSqlForPostgreSql(String tableName) {
        return "select count(*) from pg_catalog.pg_tables where tablename = '" + tableName + "'";
    }

    private static String wrapperExistTableSqlForDm(String tableName) {
        return "select count(*) from sysobjects where NAME = '" + tableName + "'";
    }

    private static String wrapperExistTableSqlForDb2(String tableName) {
        return "select count(*) from sysibm.systables where tablename = '" + tableName + "'";
    }

    private static String wrapperExistTableSqlForKingBaseEs(String tableName) {
        return "select count(*) from sys_tables where table_name = '" + tableName + "'";
    }

    private static String wrapperExistTableSqlForGauss(String tableName) {
        return "select count(*) from sys_tables where table_name = '" + tableName + "'";
    }

    public static Boolean testConnection(String url, String userName, String password) {
        Connection conn = null;
        log.info("测试数据库连接，url: " + url);
        try (SimpleDataSource ds = new SimpleDataSource(url, userName, password)) {
            ds.setLoginTimeout(50);
            conn = ds.getConnection();
        } catch (Exception e) {
            log.error("数据库连接失败！url: " + url, e);
            return false;
        } finally {
            DbUtil.close(conn);
        }
        return true;
    }



    public static DbType getDbType(String datasourceId) {
        DataSource dataSource = getDataSource(datasourceId);
        if (dataSource != null) {
            Connection connection = null;
            try {
                connection = dataSource.getConnection();
                return JdbcUtils.getDbType(connection.getMetaData().getURL());
            } catch (Exception e) {
                log.error("获取数据库类型失败！", e);
            } finally {
                DbUtil.close(connection);
            }
        }
        return null;
    }
}
